{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02d28c91",
   "metadata": {},
   "source": [
    "# TTS Aligner Inference\n",
    "\n",
    "In this notebook, we will walk through how to run inference with a **[RAD-TTS Aligner](https://arxiv.org/abs/2108.10447)** checkpoint. This tutorial will cover everything from preprocessing input text and audio to generating token duration predictions and alignments. We will be visualizing and examining these steps as we go.\n",
    "\n",
    "We will also show an example of how you can use the alignments generated by the text/audio embeddings to perform **phoneme disambiguation** of a word with multiple possible pronunciations.\n",
    "\n",
    "This tutorial uses a pretrained Aligner checkpoint from NGC and a sample from [LJSpeech](https://keithito.com/LJ-Speech-Dataset/). You should also be able to substitute your own model checkpoint and samples into the code shown, if you wish.\n",
    "\n",
    "## License\n",
    "\n",
    "> Copyright 2022 NVIDIA. All Rights Reserved.\n",
    ">\n",
    "> Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
    ">\n",
    "> `http://www.apache.org/licenses/LICENSE-2.0`\n",
    ">\n",
    "> Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df295cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "BRANCH = 'main'\n",
    "# # If you're using Colab and not running locally, uncomment and run this cell.\n",
    "# !apt-get install sox libsndfile1 ffmpeg\n",
    "# !pip install wget text-unidecode\n",
    "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d53c8e4",
   "metadata": {},
   "source": [
    "We'll need to import some libraries for loading audio, plotting various data, and of course for loading the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d463d1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start with some imports so we can visualize alignments and load the checkpoint\n",
    "%matplotlib inline\n",
    "import matplotlib.pylab as plt\n",
    "import IPython.display as ipd\n",
    "\n",
    "import librosa\n",
    "import soundfile as sf\n",
    "import torch\n",
    "\n",
    "from nemo.collections.tts.models import AlignerModel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d30138d",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Let's start by loading the checkpoint from NGC. You can find the model card [here](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/tts_en_radtts_aligner)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "896fee82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set device (GPU or CPU)\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the ARPABET Aligner model checkpoint\n",
    "aligner = AlignerModel.from_pretrained(\"tts_en_radtts_aligner\")\n",
    "\n",
    "# This should be set to whatever sample rate your model was trained on\n",
    "target_sr = 22050"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c62d2786",
   "metadata": {},
   "source": [
    "Now we'll load an audio file and input the corresponding transcript. The audio file will be resampled to the `target_sr` given above.\n",
    "\n",
    "This example uses the first sample from the NVIDIA test split of [LJSpeech](https://keithito.com/LJ-Speech-Dataset/), which is file `LJ023-0089.wav`. You can use whatever you'd like, of course, but this tutorial will refer to this sample specifically for a concrete example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d070c2e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://multilangaudiosamples.s3.us-east-2.amazonaws.com/LJ023-0089.wav"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3bbaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This tutorial uses a sample from the NVIDIA test split of LJSpeech.\n",
    "audio_path = \"./LJ023-0089.wav\"\n",
    "text_raw = \"That is not only my accusation.\"\n",
    "\n",
    "# Load audio and resample if necessary\n",
    "audio_data, orig_sr = sf.read(audio_path)\n",
    "if orig_sr != target_sr:\n",
    "    audio_data = librosa.core.resample(audio_data, orig_sr=orig_sr, target_sr=target_sr)\n",
    "\n",
    "# Let's double-check that everything matches up!\n",
    "print(f\"Duration (s): {len(audio_data)/target_sr}\")\n",
    "print(\"Transcript:\")\n",
    "print(text_raw)\n",
    "ipd.Audio(audio_data, rate=target_sr)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c686a4c6",
   "metadata": {},
   "source": [
    "### Audio Preprocessing\n",
    "\n",
    "The Aligner model takes in a mel spectrogram as input, so we'll need to convert our audio signal before we can evaluate it. The trained model has a preprocessor that will do this for us once we find the audio data length."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c36111a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve audio length for the model's preprocessor\n",
    "audio_len = torch.tensor(audio_data.shape[0], device=device).long()\n",
    "\n",
    "# Need to unsqueeze the audio data and audio_len to simulate a batch size of 1\n",
    "audio = torch.tensor(audio_data, dtype=torch.float, device=device).unsqueeze(0)\n",
    "audio_len = audio_len.unsqueeze(0)  # audio_len is already a tensor, so no need to re-wrap it\n",
    "print(f\"Audio batch shape: {audio.shape}\")\n",
    "print(f\"Audio length shape: {audio_len.shape}\\n\")\n",
    "\n",
    "# Generate the spectrogram!\n",
    "spec, spec_len = aligner.preprocessor(input_signal=audio, length=audio_len)\n",
    "print(f\"Spec batch shape: {spec.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8644dc8d",
   "metadata": {},
   "source": [
    "Let's take a look at the spectrogram to make sure it's been loaded correctly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74640183",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the spectrogram\n",
    "plt.figure(figsize=(15,5))\n",
    "_ = plt.pcolormesh(spec[0].cpu().numpy(), cmap='viridis')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9e9002e",
   "metadata": {},
   "source": [
    "If the above looks like a spectrogram, we can move on to text preprocessing."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bfd458b",
   "metadata": {},
   "source": [
    "### Text Preprocessing\n",
    "\n",
    "Now, we need to preprocess the text to be passed in to the model. This involves normalization, as well as conversion of the words in the transcript to phonemes where possible. Out-of-vocabulary (OOV) words, as well as words with multiple pronunciations, are left unconverted and kept as graphemes.\n",
    "\n",
    "Let's take a look at these steps, one at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e30a707",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, a standard English normalization of the text.\n",
    "# We set punct_post_process=True to preserve words with apostrophes, otherwise they get split.\n",
    "text_normalized = aligner.normalizer.normalize(text_raw, punct_post_process=True)\n",
    "print(text_normalized)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77ab17e4",
   "metadata": {},
   "source": [
    "At this point, we could normally just run the normalized text through the model's `tokenizer`, which would run G2P (grapheme to phoneme) conversion and spit out text tokens to pass into the model directly. But just to illustrate what happens within the tokenizer, let's take a look at its G2P step.\n",
    "\n",
    "*(If you are writing your own inference script, you can leave the code in this next cell out entirely, as it's purely illustrative.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43e6bab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The intermediate G2P step.\n",
    "# This part is usually hidden behind just calling `tokenizer()`, but we show it here to illustrate what happens.\n",
    "text_g2p = aligner.tokenizer.g2p(text_normalized)\n",
    "print(text_g2p)\n",
    "print(f\"Length: {len(text_g2p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2724fab",
   "metadata": {},
   "source": [
    "We can see that some words have been converted to phonemes (e.g. \"not\" turned into `[\"N\", \"AA1\", \"T\"]`), while some have stayed as graphemes (e.g. \"that\" is still `[\"t\", \"h\", \"a\", \"t\"]`). As mentioned above, this is because only words with a single known pronunciation are converted, while other words may have multiple possible pronunciations. CMUdict lists two for \"that\": `\"DH AE1 T\"` and `\"DH AH0 T\"`.\n",
    "\n",
    "The next cell shows what we'd normally run right after we normalize the text. This gets us our text tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5954746c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tokenizer runs G2P and then encodes each token.\n",
    "text_tokens = aligner.tokenizer(text_normalized)\n",
    "print(text_tokens)\n",
    "print(f\"Length: {len(text_tokens)}\")\n",
    "\n",
    "# We need these to be torch tensors with a batch dimension before passing them in as input, of course\n",
    "text = torch.tensor(text_tokens, device=device).unsqueeze(0).long()\n",
    "text_len = torch.tensor(len(text_tokens), device=device).unsqueeze(0).long()\n",
    "print(\"\\nAfter unsqueezing...\")\n",
    "print(f\"Text input shape: {text.shape}\")\n",
    "print(f\"Text length shape: {text_len.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f4dc62f",
   "metadata": {},
   "source": [
    "The length increases by 2 if `pad_with_space` was set for the model, which it was for this checkpoint. For ease of lining the results up later, let's update `text_g2p` to reflect this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76dc57bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Update text_g2p with spaces\n",
    "text_g2p.insert(0, ' ')\n",
    "text_g2p.insert(len(text_g2p), ' ')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fa0e741",
   "metadata": {},
   "source": [
    "Now we have our audio data and encoded text!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e112828",
   "metadata": {},
   "source": [
    "## Inference: Alignments and Phoneme Disambiguation\n",
    "\n",
    "Now that we have the audio and tokenized text, we can pass it through the trained model and get an alignment between the two inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0d6817",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the aligner!\n",
    "with torch.no_grad():\n",
    "    attn_soft_tensor, attn_logprob_tensor = aligner(spec=spec, spec_len=spec_len, text=text, text_len=text_len)\n",
    "\n",
    "# \"Unbatch\" the results\n",
    "attn_soft = attn_soft_tensor[0, 0, :, :].data.cpu().numpy()\n",
    "attn_logprob = attn_logprob_tensor[0, 0, :, :].data.cpu().numpy()\n",
    "\n",
    "print(f\"Dimensions should be (spec_len={spec_len[0].data}, text_len={text_len[0].data}) for both:\")\n",
    "print(f\"Soft attention matrix shape: {attn_soft.shape}\")\n",
    "print(f\"Log prob matrix shape: {attn_logprob.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cb9dd3b",
   "metadata": {},
   "source": [
    "### Visualizing the Alignments\n",
    "\n",
    "Now that we have the soft alignments, we can take a look at how the model matches up text tokens and audio input based on the attention matrix generated. This should roughly be a **monotonic diagonal line** running from the top left towards the bottom right.\n",
    "\n",
    "In the following cell, we transpose the **soft attention matrix** before plotting it in order to show it more \"naturally,\" that is, with the text along the vertical edge (Y-axis) and an increase in the X-axis (left-to-right) value corresponding with moving forward in time through the spectrogram."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f967826",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize soft attention matrix.\n",
    "fig, ax = plt.subplots(figsize=(12,5))\n",
    "_ = ax.imshow(attn_soft.transpose(), origin='upper', aspect='auto')\n",
    "_ = ax.set_yticks(range(len(text_g2p)))\n",
    "_ = ax.set_yticklabels(text_g2p)  # To show the text labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3647908d",
   "metadata": {},
   "source": [
    "The above is a soft attention matrix, so we can see that it is somewhat noisy.\n",
    "\n",
    "We can calculate a **hard attention matrix** to get more concrete predictions for the durations of each grapheme/phoneme. The next plot should be much sharper. We'll show the spectrogram again so we get a rough idea of what alignments match up with what spectrogram features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9f4674f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import helper function to calculate hard attention\n",
    "from nemo.collections.tts.parts.utils.helpers import binarize_attention\n",
    "\n",
    "attn_hard_tensor = binarize_attention(attn_soft_tensor, text_len, spec_len)\n",
    "attn_hard = attn_hard_tensor[0, 0, :, :].data.cpu().numpy()\n",
    "print(f\"Hard attention matrix shape: {attn_hard.shape}\")  # This should be the same as the soft attn matrix shape!\n",
    "\n",
    "# Now, let's plot the hard attention matrix.\n",
    "fig, ax = plt.subplots(2, 1, figsize=(12,10))\n",
    "_ = ax[0].imshow(attn_hard.transpose(), origin='upper', aspect='auto')\n",
    "_ = ax[0].set_yticks(range(len(text_g2p)))\n",
    "_ = ax[0].set_yticklabels(text_g2p)  # To show the text labels\n",
    "\n",
    "# This is the same spectrogram as before, but we show it here just for comparison\n",
    "_ = ax[1].pcolormesh(spec[0].cpu().numpy(), cmap='viridis')"
   ]
  },
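  {
   "cell_type": "markdown",
   "id": "e1a4c0d1",
   "metadata": {},
   "source": [
    "To build intuition for what binarization does, here is a minimal NumPy sketch of a monotonic alignment search on a toy soft attention matrix. It is an illustration under simplifying assumptions, not NeMo's `binarize_attention` implementation: the helper name `toy_monotonic_align` and the toy matrix are hypothetical. Using dynamic programming, it picks the highest-scoring path that assigns every frame to exactly one token, never moves backwards in the text, and visits every token at least once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a4c0d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def toy_monotonic_align(attn_soft):\n",
    "    # Hypothetical helper: returns a 0/1 matrix of shape (n_frames, n_tokens)\n",
    "    n_frames, n_tokens = attn_soft.shape\n",
    "    log_p = np.log(attn_soft + 1e-8)\n",
    "    # best[t, k]: best total log-probability of a path ending at frame t on token k\n",
    "    best = np.full((n_frames, n_tokens), -np.inf)\n",
    "    best[0, 0] = log_p[0, 0]  # the path must start on the first token\n",
    "    for t in range(1, n_frames):\n",
    "        for k in range(n_tokens):\n",
    "            stay = best[t - 1, k]\n",
    "            advance = best[t - 1, k - 1] if k > 0 else -np.inf\n",
    "            best[t, k] = max(stay, advance) + log_p[t, k]\n",
    "    # Backtrack from the last token at the last frame\n",
    "    hard = np.zeros((n_frames, n_tokens), dtype=int)\n",
    "    k = n_tokens - 1\n",
    "    for t in range(n_frames - 1, -1, -1):\n",
    "        hard[t, k] = 1\n",
    "        if t > 0 and k > 0 and best[t - 1, k - 1] >= best[t - 1, k]:\n",
    "            k -= 1\n",
    "    return hard\n",
    "\n",
    "# Made-up soft attention: 5 frames (rows) by 3 tokens (columns)\n",
    "toy_attn = np.array([\n",
    "    [0.8, 0.1, 0.1],\n",
    "    [0.6, 0.3, 0.1],\n",
    "    [0.2, 0.7, 0.1],\n",
    "    [0.1, 0.6, 0.3],\n",
    "    [0.1, 0.2, 0.7],\n",
    "])\n",
    "toy_hard = toy_monotonic_align(toy_attn)\n",
    "print(toy_hard)"
   ]
  },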
  {
   "cell_type": "markdown",
   "id": "015d0571",
   "metadata": {},
   "source": [
    "### Calculating Token Durations\n",
    "\n",
    "To get the duration (in frames) of each token, we would get the hard attention matrix, then sum up the number of frames that correspond to each token. Luckily, there is a function in the Aligner's encoder module that does this for us!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "364c7f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call function to calculate each token's duration in frames\n",
    "durations = aligner.alignment_encoder.get_durations(attn_soft_tensor, text_len, spec_len).int()\n",
    "\n",
    "# Let's match them up. (text_g2p already includes the leading and trailing pad spaces, so it lines up with the durations.)\n",
    "durations_sum = 0\n",
    "for t, d in zip(text_g2p, durations[0]):\n",
    "    print(f\"'{t}' duration: {d.item()}\")\n",
    "    durations_sum += d.item()\n",
    "\n",
    "# The following should be equal.\n",
    "print(f\"Total number of frames: {spec_len.item()}\")\n",
    "print(f\"Sum of durations: {durations_sum}\")"
   ]
  },
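  {
   "cell_type": "markdown",
   "id": "e2b5d1a1",
   "metadata": {},
   "source": [
    "The idea behind the duration calculation can be sketched with plain NumPy. In a hard attention matrix of shape `(spec_len, text_len)`, each frame is assigned to exactly one token, so a token's duration is the sum of its column. The toy matrix below is made up for illustration. The conversion to seconds assumes a hop length of 256 samples at 22050 Hz, which is an assumption about the preprocessor configuration and is not read from the checkpoint, so check your own model's settings before relying on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b5d1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up hard attention: 6 frames (rows) by 3 tokens (columns)\n",
    "toy_hard_attn = np.array([\n",
    "    [1, 0, 0],\n",
    "    [1, 0, 0],\n",
    "    [0, 1, 0],\n",
    "    [0, 1, 0],\n",
    "    [0, 1, 0],\n",
    "    [0, 0, 1],\n",
    "])\n",
    "\n",
    "# Summing over the frame axis gives the number of frames per token\n",
    "toy_durations = toy_hard_attn.sum(axis=0)\n",
    "print(toy_durations)\n",
    "\n",
    "# Convert frames to seconds (both values below are assumptions)\n",
    "assumed_hop_length = 256\n",
    "assumed_sample_rate = 22050\n",
    "toy_seconds = toy_durations * assumed_hop_length / assumed_sample_rate\n",
    "print(np.round(toy_seconds, 4))"
   ]
  },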
  {
   "cell_type": "markdown",
   "id": "7053ffc6",
   "metadata": {},
   "source": [
    "### Phoneme Disambiguation via Embedding Distance\n",
    "\n",
    "Remember how some words were not converted in the G2P step, and were kept as graphemes because they had multiple possible pronunciations? It turns out that we can also use a trained Aligner model to make predictions for phoneme disambiguation!\n",
    "\n",
    "We can do this by:\n",
    "\n",
    "1. Generating **one text input per possible pronunciation** (e.g. one sentence with `\"DH AE1 T\"` and one with `\"DH AH0 T\"`)\n",
    "2. **Running inference** on each (with the same spectrogram)\n",
    "3. Calculating the **distance between the text/spectrogram embeddings**\n",
    "4. Seeing **which disambiguation tokens are closer to the spectrogram**, as determined by the model.\n",
    "\n",
    "---\n",
    "\n",
    "Let's get started with our example! As a reminder, the original sentence we've used for this tutorial is:\n",
    "```\n",
    "That is not only my accusation.\n",
    "```\n",
    "\n",
    "In this sentence, \"that,\" \"is,\" and \"accusation\" have multiple entries in CMUdict. Each has two possible pronunciations, which means to disambiguate everything, we'd use six inputs:\n",
    "```\n",
    "# Disambiguate \"that\":\n",
    "DH AE1 T   i s   N AA1 T   OW1 N L IY0   M AY1   a c c u s a t i o n .\n",
    "DH AH0 T   i s   N AA1 T   OW1 N L IY0   M AY1   a c c u s a t i o n .\n",
    "\n",
    "# Disambiguate \"is\":\n",
    "t h a t   IH1 Z   N AA1 T   OW1 N L IY0   M AY1   a c c u s a t i o n .\n",
    "t h a t   IH0 Z   N AA1 T   OW1 N L IY0   M AY1   a c c u s a t i o n .\n",
    "\n",
    "# Disambiguate \"accusation\":\n",
    "t h a t   i s   N AA1 T   OW1 N L IY0   M AY1   AE2 K Y AH0 Z EY1 SH AH0 N .\n",
    "t h a t   i s   N AA1 T   OW1 N L IY0   M AY1   AE2 K Y UW0 Z EY1 SH AH0 N .\n",
    "```\n",
    "\n",
    "For brevity's sake, let's just disambiguate the word `that`. To create our two candidate inputs, we'll use the `text_g2p` that we generated earlier but cut out the letters from \"that\" and replace them with our possible pronunciations, then run them through the `EnglishPhonemesTokenizer`'s `encode_from_g2p()` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a419bb2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "### (1) Generate one text input per possible pronunciation\n",
    "\n",
    "# Construct our two candidate sentences by replacing \"t\" \"h\" \"a\" \"t\" with two phonemic possibilities\n",
    "that1 = [\"DH\", \"AE1\", \"T\"]\n",
    "that2 = [\"DH\", \"AH0\", \"T\"]\n",
    "pron1_g2p = that1 + text_g2p[5:-1] # Chop off trailing space, the tokenizer will add it\n",
    "pron2_g2p = that2 + text_g2p[5:-1] # Ditto.\n",
    "print(\"=== Text ===\")\n",
    "print(pron1_g2p)\n",
    "print(pron2_g2p)\n",
    "\n",
    "# Tokenize!\n",
    "pron1_tokens = aligner.tokenizer.encode_from_g2p(pron1_g2p)\n",
    "pron2_tokens = aligner.tokenizer.encode_from_g2p(pron2_g2p)\n",
    "print(\"\\n=== Tokens===\")\n",
    "print(pron1_tokens)\n",
    "print(pron2_tokens)\n",
    "\n",
    "# Create a batch\n",
    "disamb_text = torch.tensor([pron1_tokens, pron2_tokens], device=device).long()\n",
    "disamb_text_len = torch.tensor([len(pron1_tokens), len(pron2_tokens)], device=device).long()\n",
    "print(\"\\n=== Text/Text Length Tensor Shapes ===\")\n",
    "print(disamb_text.shape)\n",
    "print(disamb_text_len.shape)"
   ]
  },
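  {
   "cell_type": "markdown",
   "id": "e3c6e2b1",
   "metadata": {},
   "source": [
    "The slicing above is specific to this sentence. As a more general sketch, the same substitution could be wrapped in a small helper that swaps a run of graphemes for each candidate pronunciation. `make_candidates` and the toy token list below are hypothetical and are not part of NeMo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c6e2b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_candidates(g2p_tokens, start, num_graphemes, pronunciations):\n",
    "    # Hypothetical helper: replace g2p_tokens[start:start + num_graphemes]\n",
    "    # with each candidate pronunciation, returning one token list per candidate\n",
    "    return [\n",
    "        g2p_tokens[:start] + list(pron) + g2p_tokens[start + num_graphemes:]\n",
    "        for pron in pronunciations\n",
    "    ]\n",
    "\n",
    "# Toy input without pad spaces: the graphemes of 'that', a space, then 'not' as phonemes\n",
    "toy_g2p = ['t', 'h', 'a', 't', ' ', 'N', 'AA1', 'T']\n",
    "toy_candidates = make_candidates(\n",
    "    toy_g2p, start=0, num_graphemes=4,\n",
    "    pronunciations=[['DH', 'AE1', 'T'], ['DH', 'AH0', 'T']],\n",
    ")\n",
    "for candidate in toy_candidates:\n",
    "    print(candidate)"
   ]
  },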
  {
   "cell_type": "markdown",
   "id": "49485889",
   "metadata": {},
   "source": [
    "And again, we'll insert a space at the beginning and a space at the end because `pad_with_space` is set to True in the tokenizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a7422c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert spaces to reflect the padded token vector\n",
    "pron1_g2p.insert(0, ' ')\n",
    "pron1_g2p.insert(len(pron1_g2p), ' ')\n",
    "print(len(pron1_g2p))\n",
    "\n",
    "pron2_g2p.insert(0, ' ')\n",
    "pron2_g2p.insert(len(pron2_g2p), ' ')\n",
    "print(len(pron2_g2p))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e27a2b8",
   "metadata": {},
   "source": [
    "Note that we have picked an example where both disambiguations have the same tokenized length (which will be the case most of the time). If you have a case where the two pronunciations have different lengths, you may need to perform some padding to get the batch to line up.\n",
    "\n",
    "Let's run inference on the new inputs. These two text inputs are candidates for the same spectrogram, so we'll duplicate the spectrogram input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ba2fd16",
   "metadata": {},
   "outputs": [],
   "source": [
    "### (2) Run inference on each candidate\n",
    "\n",
    "# Duplicate spec and spec_len to match the two text inputs\n",
    "spec_2 = spec.repeat([2, 1, 1])\n",
    "spec_len_2 = spec_len.repeat([2])\n",
    "\n",
    "# Inference with two inputs\n",
    "with torch.no_grad():\n",
    "    disamb_attn_soft_tensor, _ = aligner(\n",
    "        spec=spec_2,\n",
    "        spec_len=spec_len_2,\n",
    "        text=disamb_text,\n",
    "        text_len=disamb_text_len\n",
    "    )\n",
    "\n",
    "# \"Unbatch\" the results\n",
    "disamb_attn_soft = disamb_attn_soft_tensor[:, 0, :, :].data.cpu().numpy()\n",
    "print(f\"Dimensions should be (2, spec_len={spec_len_2[0].data}, text_len={max(disamb_text_len.data)}):\")\n",
    "print(f\"Soft attention matrix shape: {disamb_attn_soft.shape}\")"
   ]
  },
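  {
   "cell_type": "markdown",
   "id": "e4d7f3c1",
   "metadata": {},
   "source": [
    "If your candidates tokenize to different lengths, as noted earlier, the token lists need to be padded before they can be stacked into one tensor. Here is a minimal pure-Python sketch. `pad_token_lists` is a hypothetical helper, and the padding id of `0` is an assumption: use whatever padding id your tokenizer defines. The original, unpadded lengths are kept so that they can be passed in as `text_len`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d7f3c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_token_lists(token_lists, pad_id):\n",
    "    # Hypothetical helper: right-pad every list to the length of the longest one\n",
    "    max_len = max(len(tokens) for tokens in token_lists)\n",
    "    padded = [tokens + [pad_id] * (max_len - len(tokens)) for tokens in token_lists]\n",
    "    lengths = [len(tokens) for tokens in token_lists]\n",
    "    return padded, lengths\n",
    "\n",
    "assumed_pad_id = 0  # assumption: replace with your tokenizer's padding id\n",
    "toy_padded, toy_lengths = pad_token_lists([[5, 9, 2], [5, 9, 7, 2]], assumed_pad_id)\n",
    "print(toy_padded)\n",
    "print(toy_lengths)"
   ]
  },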
  {
   "cell_type": "markdown",
   "id": "fdb21299",
   "metadata": {},
   "source": [
    "Next, we retrieve the L2 distance matrix between each text embedding and its corresponding spectrogram embedding. There is an alignment encoder function called `get_dist()` that will calculate $(\\texttt{text_emb[i]} - \\texttt{spec_emb[j]})^2$ for all pairs of text tokens and spectrogram timesteps, and we can get the L2 distance matrix by square-rooting those values.\n",
    "\n",
    "(Note that darker = smaller distance, so we should see a dark diagonal of a similar shape to the lines above.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea145449",
   "metadata": {},
   "outputs": [],
   "source": [
    "### (3) Calculate distance between text and spectrogram embeddings for each candidate\n",
    "\n",
    "# Housekeeping: we first need to get the text embedding from the Aligner encoder\n",
    "disamb_text_embs = aligner.embed(disamb_text).transpose(1,2)\n",
    "\n",
    "# Run the Aligner encoder to get the distances between the key (text) and query (spectrogram) embeddings.\n",
    "square_dists = aligner.alignment_encoder.get_dist(keys=disamb_text_embs, queries=spec_2)\n",
    "l2_dists = square_dists.sqrt()\n",
    "\n",
    "# We can plot the L2 distances now\n",
    "l2_dists_data = l2_dists.data.cpu().numpy()\n",
    "fig, ax = plt.subplots(2, 1, figsize=(12,10))\n",
    "\n",
    "# Here, we trim the first and last time steps (the zero-padding)\n",
    "_ = ax[0].imshow(l2_dists_data[0, 1:-1].transpose(), origin='upper', aspect='auto')\n",
    "_ = ax[0].set_yticks(range(len(pron1_g2p)))\n",
    "_ = ax[0].set_yticklabels(pron1_g2p)  # To show the text labels\n",
    "_ = ax[0].set_title(\"\\\"DH AE1 T\\\" Candidate - Embedding L2 Distance Matrix\")\n",
    "\n",
    "_ = ax[1].imshow(l2_dists_data[1, 1:-1].transpose(), origin='upper', aspect='auto')\n",
    "_ = ax[1].set_yticks(range(len(pron2_g2p)))\n",
    "_ = ax[1].set_yticklabels(pron2_g2p)\n",
    "_ = ax[1].set_title(\"\\\"DH AH0 T\\\" Candidate - Embedding L2 Distance Matrix\")"
   ]
  },
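  {
   "cell_type": "markdown",
   "id": "e5e8a4d1",
   "metadata": {},
   "source": [
    "For intuition, the pairwise distance computation can be written in a few lines of NumPy. This is a sketch on random toy embeddings: the shapes and the names `toy_text_emb` and `toy_spec_emb` are made up, and it skips any learned projection layers the model applies before comparing embeddings, so the numbers are not comparable to the model's output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5e8a4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "toy_text_emb = rng.normal(size=(4, 8))   # (n_tokens, emb_dim)\n",
    "toy_spec_emb = rng.normal(size=(10, 8))  # (n_frames, emb_dim)\n",
    "\n",
    "# Broadcast to (n_frames, n_tokens, emb_dim), then sum squared differences over emb_dim\n",
    "toy_diff = toy_spec_emb[:, None, :] - toy_text_emb[None, :, :]\n",
    "toy_square_dists = (toy_diff ** 2).sum(axis=-1)\n",
    "\n",
    "# Square-rooting the summed squares gives the L2 distance for every (frame, token) pair\n",
    "toy_l2_dists = np.sqrt(toy_square_dists)\n",
    "print(toy_l2_dists.shape)"
   ]
  },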
  {
   "cell_type": "markdown",
   "id": "28e083d5",
   "metadata": {},
   "source": [
    "The last step is to calculate the average distance between the text tokens for \"that\" and their corresponding audio frames. **We expect that the candidate pronunciation that's the closest to the audio should be the most representative of the actual speech.**\n",
    "\n",
    "To do this, we need to get each token's durations, which will let us isolate only the (predicted) frames that correspond to `DH AE1 T` and `DH AH0 T` respectively. Then, the Aligner's encoder has a function called `get_mean_distance_for_word()` that will calculate the average distance over the frames corresponding only to the tokens in the word."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f929f6df",
   "metadata": {},
   "outputs": [],
   "source": [
    "### (4) Check which disambiguation is closer to the speech\n",
    "\n",
    "# Get durations like before; the batch size of 2 shouldn't change how we call the function.\n",
    "disamb_durations = aligner.alignment_encoder.get_durations(\n",
    "    disamb_attn_soft_tensor,\n",
    "    disamb_text_len,\n",
    "    spec_len_2\n",
    ").int()\n",
    "\n",
    "# Retrieve the average embedding distances for each pronunciation of \"that\"\n",
    "that1_mean_dist = aligner.alignment_encoder.get_mean_distance_for_word(\n",
    "    l2_dists=l2_dists[0],\n",
    "    durs=disamb_durations[0],\n",
    "    start_token=1,  # Remember to account for space padding\n",
    "    num_tokens=len(that1)\n",
    ")\n",
    "that2_mean_dist = aligner.alignment_encoder.get_mean_distance_for_word(\n",
    "    l2_dists=l2_dists[1],\n",
    "    durs=disamb_durations[1],\n",
    "    start_token=1,  # Here as well\n",
    "    num_tokens=len(that2)\n",
    ")\n",
    "\n",
    "print(f\"Average distance for {that1}: {that1_mean_dist}\")\n",
    "print(f\"Average distance for {that2}: {that2_mean_dist}\")"
   ]
  },
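  {
   "cell_type": "markdown",
   "id": "e6f9b5e1",
   "metadata": {},
   "source": [
    "To make the word-level averaging concrete, here is a NumPy sketch of the idea: walk through the word's tokens, use the durations to find which frames each token is aligned to, and average the distances at those (frame, token) positions. `toy_mean_distance_for_word` is a hypothetical re-implementation on made-up data, not NeMo's code, and it assumes the distance matrix has shape `(spec_len, text_len)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f9b5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def toy_mean_distance_for_word(l2_dists, durs, start_token, num_tokens):\n",
    "    # The word's first frame comes right after all frames used by earlier tokens\n",
    "    start_frame = int(np.sum(durs[:start_token]))\n",
    "    dist_sum, total_frames = 0.0, 0\n",
    "    for token in range(start_token, start_token + num_tokens):\n",
    "        end_frame = start_frame + int(durs[token])\n",
    "        # Add up the distances for the frames aligned to this token\n",
    "        dist_sum += l2_dists[start_frame:end_frame, token].sum()\n",
    "        total_frames += int(durs[token])\n",
    "        start_frame = end_frame\n",
    "    return dist_sum / total_frames\n",
    "\n",
    "# Made-up data: 6 frames by 4 tokens, where entry [f, t] equals 4 * f + t\n",
    "toy_dists = np.arange(24, dtype=float).reshape(6, 4)\n",
    "toy_durs = np.array([1, 2, 2, 1])\n",
    "toy_mean = toy_mean_distance_for_word(toy_dists, toy_durs, start_token=1, num_tokens=2)\n",
    "print(toy_mean)"
   ]
  },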
  {
   "cell_type": "markdown",
   "id": "6b336cc0",
   "metadata": {},
   "source": [
    "And we're done!\n",
    "\n",
    "**With the average distance for `DH AE1 T` being about 377 and the average distance for `DH AH0 T` being about 403, we can pick `DH AE1 T` as the better match, since a smaller distance means the candidate is closer to the audio.**\n",
    "\n",
    "As an exercise, try editing the blocks of code above to disambiguate \"accusation\" (`AE2 K Y AH0 Z EY1 SH AH0 N` versus `AE2 K Y UW0 Z EY1 SH AH0 N`)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c802394c",
   "metadata": {},
   "source": [
    "#### Addendum\n",
    "\n",
    "As a side note, there is also a function called `get_mean_dist_by_durations()` that will match up the distances between each token and its corresponding spectrogram frames (using the previously-calculated durations), then calculate the mean for each item in the batch.\n",
    "\n",
    "The whole-sentence average may not tell us very much here because we just want to know which pronunciation is closest to what's being said for a specific word, but it's there if you need it as an extra metric!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c138a66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean distance for each item in the batch:\n",
    "mean_dists = aligner.alignment_encoder.get_mean_dist_by_durations(\n",
    "    dist=l2_dists.to('cpu'),\n",
    "    durations=disamb_durations.to('cpu')\n",
    ")\n",
    "print(mean_dists)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "842c3667",
   "metadata": {},
   "source": [
    "## Resources\n",
    "\n",
    "- For more information about the Aligner architecture, check out the [RAD-TTS Aligner paper](https://arxiv.org/abs/2108.10447).\n",
    "- If you would like to run disambiguation on a large batch of sentences, try out the [Aligner disambiguation example script](https://github.com/NVIDIA/NeMo/blob/main/examples/tts/aligner_heteronym_disambiguation.py)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
